# Dirichlet–Multinomial Distribution (dirichlet_multinomial)
The Dirichlet–multinomial (a.k.a. Dirichlet compound multinomial) is a discrete multivariate distribution over count vectors. It arises when the category probabilities are themselves uncertain: draw probabilities $p$ from a Dirichlet distribution, then draw counts $X$ from a multinomial given $p$.
## Learning goals

By the end you should be able to:

- explain the Dirichlet–multinomial as a “multinomial with random probabilities” and why it captures overdispersion
- write down the PMF and understand its constraints (support + parameter space)
- derive the mean and covariance from the hierarchical model
- sample from it in NumPy and visualize it (1D and simplex plots)
- use SciPy’s `scipy.stats.dirichlet_multinomial` for PMF/moments, and implement missing pieces (CDF/sampling/fit) yourself
## Prerequisites

- basic probability (expectation, variance, conditional expectation)
- familiarity with the multinomial and Dirichlet distributions
- comfort reading Gamma/Beta functions
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import os
import plotly.io as pio
from scipy import stats
from scipy.optimize import minimize
from scipy.special import digamma, gammaln, logsumexp
pio.renderers.default = os.environ.get("PLOTLY_RENDERER", "notebook")
rng = np.random.default_rng(7)
np.set_printoptions(precision=4, suppress=True)
## 1) Title & Classification
- **Name**: `dirichlet_multinomial` (Dirichlet–multinomial, Dirichlet compound multinomial)
- **Type**: **Discrete** (multivariate counts)
- **Support** (for $K$ categories and total count $n$):
$$
\mathcal{S}_{n,K} = \left\{x \in \{0,1,2,\dots\}^K : \sum_{i=1}^K x_i = n\right\}
$$
- **Parameter space**:
  - $n \in \{0,1,2,\dots\}$ (integer total count)
  - $\alpha = (\alpha_1,\dots,\alpha_K)$ with $\alpha_i > 0$
  - define $\alpha_0 = \sum_{i=1}^K \alpha_i$ (total concentration)
A draw $X \sim \text{DirichletMultinomial}(n,\alpha)$ is a **count vector** with a fixed total: $\sum_i X_i = n$.
## 2) Intuition & Motivation
### What it models
The Dirichlet–multinomial models **counts across categories** when the category probabilities themselves vary across trials/replicates.
A common hierarchical story is:
$$
p \sim \mathrm{Dirichlet}(\alpha),
\qquad
X \mid p \sim \mathrm{Multinomial}(n, p).
$$
If $p$ were fixed, you’d just have a multinomial. But if $p$ changes (e.g., each document has a different word distribution), the multinomial is often **too confident**.
The Dirichlet–multinomial captures this extra variability (“**overdispersion**”) and induces **negative correlations** between counts because they must sum to $n$.
### Typical real-world use cases
- **Text / NLP**: bag-of-words counts (posterior predictive of a Dirichlet–multinomial model)
- **Ecology**: species counts across sites with heterogeneous composition
- **Genomics**: overdispersed categorical counts (e.g., allelic counts)
- **A/B testing on categories**: uncertainty in category probabilities across cohorts
### Relations to other distributions
- **Dirichlet + Multinomial**: it is the **Dirichlet mixture** of multinomials.
- **Beta–binomial**: when $K=2$, the first component $X_1$ is Beta–binomial.
- **Multinomial limit**: as $\alpha_0 \to \infty$ with $\alpha/\alpha_0$ fixed, the Dirichlet–multinomial approaches a multinomial with fixed probabilities.
- **Pólya urn**: an equivalent sampling scheme is “reinforcement” sampling where each draw increases the chance of drawing that category again (see the sketch below).
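As a quick check of the Pólya-urn view, here is a minimal sketch (the helper `polya_urn_sample` is ours, not part of the code developed below): each draw is made proportionally to the current urn weights, and the drawn category’s weight is then reinforced by one.

def polya_urn_sample(alpha, n: int, rng: np.random.Generator) -> np.ndarray:
    '''One Dirichlet–multinomial draw via Pólya-urn reinforcement.'''
    weights = np.asarray(alpha, dtype=float).copy()
    counts = np.zeros(weights.size, dtype=int)
    for _ in range(n):
        i = rng.choice(weights.size, p=weights / weights.sum())
        counts[i] += 1      # record the draw
        weights[i] += 1.0   # reinforcement: category i becomes more likely
    return counts

draws = np.array([polya_urn_sample([1.0, 2.0, 3.0], n=10, rng=rng) for _ in range(5000)])
draws.mean(axis=0)  # should be close to n * alpha / alpha0 = [1.67, 3.33, 5.0]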
## 3) Formal Definition
Let $X=(X_1,\dots,X_K)$ be a count vector with $\sum_i X_i = n$.
### PMF (discrete analogue of a PDF)
For $x \in \mathcal{S}_{n,K}$:
$$
\Pr(X=x \mid n,\alpha)
= \frac{n!}{\prod_{i=1}^K x_i!}
\frac{\Gamma(\alpha_0)}{\Gamma(\alpha_0+n)}
\prod_{i=1}^K \frac{\Gamma(\alpha_i + x_i)}{\Gamma(\alpha_i)}.
$$
Using the rising factorial (Pochhammer symbol) $(a)_m = \Gamma(a+m)/\Gamma(a)$, this can be written:
$$
\Pr(X=x \mid n,\alpha)
= \frac{n!}{\prod_i x_i!}
\frac{\prod_i (\alpha_i)_{x_i}}{(\alpha_0)_n}.
$$
As a sanity check, for $n=1$ the PMF reduces to $\Pr(X = e_i) = \alpha_i/\alpha_0$ (with $e_i$ the $i$-th unit vector): a single draw lands in category $i$ with its prior mean probability.
### CDF
A common multivariate “CDF” is the **lower-orthant CDF**:
$$
F(x) = \Pr(X_1 \le x_1,\dots,X_K \le x_K)
= \sum_{y \in \mathcal{S}_{n,K}:\; y_i \le x_i\;\forall i} \Pr(X=y).
$$
There is no simple closed form in general. For $K=2$, this reduces to the usual **univariate CDF** of the Beta–binomial distribution.
def _validate_alpha(alpha) -> np.ndarray:
alpha = np.asarray(alpha, dtype=float)
if alpha.ndim != 1:
raise ValueError("alpha must be a 1D array of positive values")
if alpha.size < 2:
raise ValueError("alpha must have length K>=2")
if not np.all(np.isfinite(alpha)):
raise ValueError("alpha must be finite")
if np.any(alpha <= 0):
raise ValueError("alpha must be strictly positive")
return alpha
def _validate_counts(x, k: int) -> np.ndarray:
x = np.asarray(x)
if x.ndim == 1:
x = x[None, :]
if x.ndim != 2 or x.shape[1] != k:
raise ValueError(f"x must have shape (k,) or (m,k) with k={k}")
if not np.issubdtype(x.dtype, np.integer):
if np.any(np.abs(x - np.round(x)) > 0):
raise ValueError("x must contain integers")
x = np.round(x).astype(int)
else:
x = x.astype(int)
if np.any(x < 0):
raise ValueError("x must be nonnegative")
return x
def dirichlet_multinomial_logpmf(x, alpha, n: int | None = None) -> np.ndarray:
'''Vectorized log PMF for the Dirichlet–multinomial.
Parameters
----------
x:
Count vector(s), shape (k,) or (m,k). Each row must sum to n.
alpha:
Concentration parameters (k,), alpha_i > 0.
n:
Total count. If None, inferred from x row sums.
'''
alpha = _validate_alpha(alpha)
x = _validate_counts(x, k=alpha.size)
row_sums = x.sum(axis=1)
if n is None:
n_vec = row_sums
else:
if np.any(row_sums != n):
raise ValueError("Each row of x must sum to n")
n_vec = np.full_like(row_sums, fill_value=n)
alpha0 = alpha.sum()
log_multinomial_coeff = gammaln(n_vec + 1) - np.sum(gammaln(x + 1), axis=1)
log_norm = gammaln(alpha0) - gammaln(alpha0 + n_vec)
log_ratio = np.sum(gammaln(alpha + x) - gammaln(alpha), axis=1)
out = log_multinomial_coeff + log_norm + log_ratio
return out[0] if out.size == 1 else out
def dirichlet_multinomial_pmf(x, alpha, n: int | None = None) -> np.ndarray:
return np.exp(dirichlet_multinomial_logpmf(x, alpha=alpha, n=n))
def compositions(n: int, k: int):
'''Generate all k-tuples of nonnegative integers summing to n (stars and bars).'''
if k == 1:
yield (n,)
return
for i in range(n + 1):
for tail in compositions(n - i, k - 1):
yield (i,) + tail
def enumerate_support(n: int, k: int) -> np.ndarray:
'''Enumerate the support S_{n,k}. Size is comb(n+k-1, k-1).'''
return np.array(list(compositions(n, k)), dtype=int)
def dm_cdf_small_n(x, alpha, n: int) -> float:
'''Lower-orthant CDF by brute-force summation (only feasible for small n,k).'''
alpha = _validate_alpha(alpha)
x = _validate_counts(x, k=alpha.size)[0]
if x.sum() != n:
raise ValueError("x must sum to n")
ys = enumerate_support(n=n, k=alpha.size)
mask = np.all(ys <= x[None, :], axis=1)
return float(np.sum(dirichlet_multinomial_pmf(ys[mask], alpha=alpha, n=n)))
def simplex_xy_3(counts: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
'''Map 3-category compositions to 2D barycentric coordinates for plotting.'''
counts = np.asarray(counts, dtype=float)
counts = np.atleast_2d(counts)
if counts.shape[1] != 3:
raise ValueError("simplex_xy_3 expects shape (m,3)")
n = counts.sum(axis=1)
if np.any(n <= 0):
raise ValueError("All rows must sum to a positive n")
p = counts / n[:, None]
x = p[:, 1] + 0.5 * p[:, 2]
y = (np.sqrt(3) / 2.0) * p[:, 2]
return x, y
def dirichlet_rvs_numpy(alpha, size: int, rng: np.random.Generator) -> np.ndarray:
'''Sample Dirichlet(alpha) via Gamma normalization (NumPy-only).'''
alpha = _validate_alpha(alpha)
g = rng.gamma(shape=alpha, scale=1.0, size=(size, alpha.size))
return g / g.sum(axis=1, keepdims=True)
def dirichlet_multinomial_rvs_numpy(alpha, n: int, size: int, rng: np.random.Generator) -> np.ndarray:
'''Sample Dirichlet–multinomial(n, alpha) (NumPy-only).
Algorithm:
1) p ~ Dirichlet(alpha)
2) X | p ~ Multinomial(n, p)
'''
alpha = _validate_alpha(alpha)
ps = dirichlet_rvs_numpy(alpha, size=size, rng=rng)
out = np.empty((size, alpha.size), dtype=int)
for i, p in enumerate(ps):
out[i] = rng.multinomial(n, p)
return out
# Quick sanity check against SciPy's PMF
alpha = np.array([1.0, 2.0, 3.0])
n = 10
x = np.array([2, 3, 5])
pmf_numpy = dirichlet_multinomial_pmf(x, alpha=alpha, n=n)
pmf_scipy = stats.dirichlet_multinomial.pmf(x, alpha=alpha, n=n)
pmf_numpy, pmf_scipy, float(pmf_numpy - pmf_scipy)
(0.027972027972027885, 0.02797202797202796, -7.632783294297951e-17)
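Another quick sanity check: the PMF should sum to 1 over the enumerated support (feasible only for small $n$ and $K$).

# PMF normalization over the full support S_{n,K}
support = enumerate_support(n=6, k=3)
total = dirichlet_multinomial_pmf(support, alpha=np.array([1.0, 2.0, 3.0]), n=6).sum()
total  # should be 1.0 up to floating-point error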
## 4) Moments & Properties
### Mean
Using the hierarchical model $p \sim \mathrm{Dirichlet}(\alpha)$ and $X\mid p \sim \mathrm{Multinomial}(n,p)$,
$$
\mathbb{E}[X_i] = n\,\frac{\alpha_i}{\alpha_0}.
$$
### Covariance
For $i \ne j$:
$$
\mathrm{Cov}(X_i, X_j)
= -\,n\,\frac{\alpha_i\alpha_j}{\alpha_0^2}\,\frac{n+\alpha_0}{\alpha_0+1}.
$$
For the variance (the $i=j$ case):
$$
\mathrm{Var}(X_i)
= n\,\frac{\alpha_i}{\alpha_0}\left(1-\frac{\alpha_i}{\alpha_0}\right)\,\frac{n+\alpha_0}{\alpha_0+1}.
$$
### Marginals (Beta–binomial)
Each component $X_i$ marginally follows a Beta–binomial distribution:
$$
X_i \sim \mathrm{BetaBinomial}\big(n,\; \alpha_i,\; \alpha_0-\alpha_i\big).
$$
This is useful because it gives you **univariate** quantities like skewness and kurtosis for each component.
A clean way to get higher moments is via **factorial moments**. For $r\in\{1,2,3,\dots\}$:
$$
\mathbb{E}\big[X_i^{\underline{r}}\big] = n^{\underline{r}}\,\frac{(\alpha_i)_{r}}{(\alpha_0)_{r}},
$$
where $(a)_r = a(a+1)\cdots(a+r-1)$ is the rising factorial and $a^{\underline{r}} = a(a-1)\cdots(a-r+1)$ is the falling factorial. Note that $n$ enters through its *falling* factorial, since $\mathbb{E}[X_i^{\underline{r}}\mid p] = n^{\underline{r}}\,p_i^r$ and $\mathbb{E}[p_i^r] = (\alpha_i)_r/(\alpha_0)_r$.
From these you can reconstruct raw/central moments (and thus skewness/kurtosis).
### MGF / characteristic function
For a vector $t\in\mathbb{R}^K$, the MGF can be written as a (finite) sum over the support:
$$
M_X(t) = \mathbb{E}[e^{t^\top X}] = \sum_{x\in\mathcal{S}_{n,K}} e^{t^\top x}\,\Pr(X=x).
$$
Equivalently, via the mixture:
$$
M_X(t) = \mathbb{E}_{p\sim\mathrm{Dir}(\alpha)}\left[\left(\sum_{i=1}^K p_i e^{t_i}\right)^n\right].
$$
Closed forms involve special functions (multivariate hypergeometric functions). For small $n$ you can compute it by enumeration.
The characteristic function is $\varphi(\omega)=M_X(i\omega)$.
### Entropy
The Shannon entropy is
$$
H(X) = -\sum_{x\in\mathcal{S}_{n,K}} \Pr(X=x)\,\log \Pr(X=x).
$$
There is no simple closed form in general; you can compute it exactly by enumeration for small $n$, or estimate it by Monte Carlo.
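For larger $n$, where enumeration is infeasible, a Monte Carlo estimate of $H(X) = -\mathbb{E}[\log \Pr(X)]$ is straightforward. A minimal sketch (the helper `dm_entropy_mc` is ours, built on the sampler and log-PMF defined above):

def dm_entropy_mc(alpha, n: int, size: int, rng: np.random.Generator) -> float:
    '''Monte Carlo entropy estimate: average of -logpmf over samples (in nats).'''
    samples = dirichlet_multinomial_rvs_numpy(alpha, n=n, size=size, rng=rng)
    return float(-np.mean(dirichlet_multinomial_logpmf(samples, alpha=alpha, n=n)))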
def dm_mean(alpha, n: int) -> np.ndarray:
alpha = _validate_alpha(alpha)
return n * alpha / alpha.sum()
def dm_cov(alpha, n: int) -> np.ndarray:
alpha = _validate_alpha(alpha)
k = alpha.size
alpha0 = alpha.sum()
# Cov(X_i, X_j) = -n * alpha_i*alpha_j / alpha0^2 * (n+alpha0)/(alpha0+1)
factor = n * (n + alpha0) / (alpha0 + 1.0) / (alpha0**2)
cov = -factor * np.outer(alpha, alpha)
# Fix diagonal to variance formula
p = alpha / alpha0
var = n * p * (1.0 - p) * (n + alpha0) / (alpha0 + 1.0)
np.fill_diagonal(cov, var)
return cov
def beta_binomial_moments_via_factorials(n: int, a: float, b: float):
'''Return (mean, variance, skewness, excess_kurtosis) for BetaBinomial(n,a,b).
Uses factorial moments (stable + avoids a giant closed-form expression).
'''
a0 = a + b
# Falling factorial moments of X: E[(X)_r] = (n)_r E[p^r]
# with p~Beta(a,b), E[p^r] = (a)_r/(a0)_r (rising factorial).
f1 = n * a / a0
f2 = n * (n - 1) * a * (a + 1) / (a0 * (a0 + 1)) if n >= 2 else 0.0
f3 = (
n * (n - 1) * (n - 2) * a * (a + 1) * (a + 2) / (a0 * (a0 + 1) * (a0 + 2))
if n >= 3
else 0.0
)
f4 = (
n
* (n - 1)
* (n - 2)
* (n - 3)
* a
* (a + 1)
* (a + 2)
* (a + 3)
/ (a0 * (a0 + 1) * (a0 + 2) * (a0 + 3))
if n >= 4
else 0.0
)
# Stirling-number conversion (X^r = sum_k S(r,k) (X)_k)
m1 = f1
m2 = f1 + f2
m3 = f1 + 3 * f2 + f3
m4 = f1 + 7 * f2 + 6 * f3 + f4
mu = m1
mu2 = m2 - mu**2
mu3 = m3 - 3 * m2 * mu + 2 * mu**3
mu4 = m4 - 4 * m3 * mu + 6 * m2 * mu**2 - 3 * mu**4
skew = mu3 / (mu2 ** 1.5) if mu2 > 0 else np.nan
kurt_excess = mu4 / (mu2**2) - 3.0 if mu2 > 0 else np.nan
return mu, mu2, skew, kurt_excess
def dm_entropy_small_n(alpha, n: int) -> float:
alpha = _validate_alpha(alpha)
xs = enumerate_support(n=n, k=alpha.size)
logp = dirichlet_multinomial_logpmf(xs, alpha=alpha, n=n)
p = np.exp(logp)
return float(-np.sum(p * logp))
def dm_mgf_small_n(t, alpha, n: int) -> float:
alpha = _validate_alpha(alpha)
t = np.asarray(t, dtype=float)
if t.shape != alpha.shape:
raise ValueError(f"t must have shape {alpha.shape}")
xs = enumerate_support(n=n, k=alpha.size)
logp = dirichlet_multinomial_logpmf(xs, alpha=alpha, n=n)
return float(np.exp(logsumexp(logp + xs @ t)))
# Example: moments for a 3-category model
alpha = np.array([1.5, 2.0, 4.5])
n = 20
mean = dm_mean(alpha, n=n)
cov = dm_cov(alpha, n=n)
ent = dm_entropy_small_n(alpha, n=n)
mean, cov, ent
(array([ 3.75, 5. , 11.25]),
array([[ 9.4792, -2.9167, -6.5625],
[-2.9167, 11.6667, -8.75 ],
[-6.5625, -8.75 , 15.3125]]),
4.885255849427407)
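The MGF helper defined above is easy to sanity-check: $M_X(0)=1$, and a finite-difference derivative at $t=0$ should recover the mean vector.

# MGF sanity checks: M_X(0) = 1, and dM/dt_i at t=0 equals E[X_i]
t0 = np.zeros(3)
print('M(0) =', dm_mgf_small_n(t0, alpha, n=n))
eps = 1e-6
fd_grad = np.array([
    (dm_mgf_small_n(t0 + eps * np.eye(3)[i], alpha, n=n) - 1.0) / eps
    for i in range(3)
])
print('finite-difference gradient:', fd_grad)  # compare with dm_mean(alpha, n=n)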
# Marginal skewness/kurtosis via the Beta–binomial identity
alpha0 = alpha.sum()
for i in range(alpha.size):
a = alpha[i]
b = alpha0 - alpha[i]
m, v, s, kex = beta_binomial_moments_via_factorials(n=n, a=float(a), b=float(b))
print(
f"X_{i+1}: mean={m:.3f}, var={v:.3f}, skew={s:.3f}, excess_kurtosis={kex:.3f} "
f"(BetaBinomial(n={n}, a={a:.2f}, b={b:.2f}))"
)
# Cross-check against SciPy's betabinom for one component
i = 0
a = float(alpha[i])
b = float(alpha0 - alpha[i])
scipy_mean, scipy_var, scipy_skew, scipy_kex = stats.betabinom.stats(n=n, a=a, b=b, moments="mvsk")
(scipy_mean, scipy_var, scipy_skew, scipy_kex)
X_1: mean=3.750, var=9.479, skew=0.974, excess_kurtosis=0.711 (BetaBinomial(n=20, a=1.50, b=6.50))
X_2: mean=5.000, var=11.667, skew=0.703, excess_kurtosis=0.097 (BetaBinomial(n=20, a=2.00, b=6.00))
X_3: mean=11.250, var=15.312, skew=-0.153, excess_kurtosis=-0.537 (BetaBinomial(n=20, a=4.50, b=3.50))
(3.75, 9.479166666666666, 0.9743975315293802, 0.7108891108891111)
## 5) Parameter Interpretation
Think of $\alpha$ as **prior pseudo-counts** for category probabilities.
- The **mean proportions** are
$$\mathbb{E}[p_i] = \frac{\alpha_i}{\alpha_0},\quad \alpha_0=\sum_i\alpha_i.$$
- The **total concentration** $\alpha_0$ controls how much $p$ varies across replicates (quantified in the snippet below):
  - small $\alpha_0$ → $p$ is highly variable → counts are **more dispersed** (more mass near simplex corners)
  - large $\alpha_0$ → $p$ concentrates near its mean → counts look more like a plain multinomial
- Holding $\alpha_0$ fixed, changing the **ratios** $\alpha_i/\alpha_0$ shifts mass toward categories with larger ratios.
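This dispersion effect can be quantified before plotting: by the variance formula in Section 4, each component variance equals the matching multinomial variance times $(n+\alpha_0)/(\alpha_0+1)$, which tends to 1 as $\alpha_0\to\infty$. A quick illustration:

# Variance inflation relative to a plain multinomial with the same mean
n = 25
for alpha0 in [3.0, 30.0, 300.0]:
    print(f"alpha0={alpha0:6.1f} -> variance inflation = {(n + alpha0) / (alpha0 + 1.0):.2f}")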
Below we visualize samples for the same mean proportions but different total concentration.
# Same mean proportions, different concentration alpha0
n = 25
base = np.array([1.0, 2.0, 3.0])
base = base / base.sum() # mean proportions
scales = [0.3, 1.0, 5.0]
size = 2500
fig = go.Figure()
# draw simplex triangle
tri_x = [0.0, 1.0, 0.5, 0.0]
tri_y = [0.0, 0.0, np.sqrt(3) / 2.0, 0.0]
fig.add_trace(
go.Scatter(x=tri_x, y=tri_y, mode="lines", line=dict(color="black"), showlegend=False)
)
for s in scales:
alpha_s = s * base * 30.0 # scale into a reasonable pseudo-count regime
samples = dirichlet_multinomial_rvs_numpy(alpha=alpha_s, n=n, size=size, rng=rng)
x, y = simplex_xy_3(samples)
fig.add_trace(
go.Scattergl(
x=x,
y=y,
mode="markers",
name=f"alpha0≈{alpha_s.sum():.1f}",
marker=dict(size=4, opacity=0.25),
)
)
fig.update_layout(
title="Dirichlet–multinomial samples on the 3-simplex (same mean, different concentration)",
xaxis_title="barycentric x",
yaxis_title="barycentric y",
xaxis=dict(scaleanchor="y", scaleratio=1),
width=850,
height=500,
)
fig.show()
## 6) Derivations
We derive mean and covariance using the mixture representation:
$$
p \sim \mathrm{Dirichlet}(\alpha),
\qquad
X \mid p \sim \mathrm{Multinomial}(n, p).
$$
### Expectation
By the law of total expectation:
$$
\mathbb{E}[X_i] = \mathbb{E}\big[\,\mathbb{E}[X_i\mid p]\,\big]
= \mathbb{E}[n p_i]
= n\,\mathbb{E}[p_i]
= n\,\frac{\alpha_i}{\alpha_0}.
$$
### Variance
By the law of total variance:
$$
\mathrm{Var}(X_i) = \mathbb{E}[\mathrm{Var}(X_i\mid p)] + \mathrm{Var}(\mathbb{E}[X_i\mid p]).
$$
For a multinomial:
$$
\mathbb{E}[X_i\mid p] = n p_i,
\qquad
\mathrm{Var}(X_i\mid p) = n p_i(1-p_i).
$$
So:
$$
\mathrm{Var}(X_i)
= \mathbb{E}[n p_i(1-p_i)] + \mathrm{Var}(n p_i)
= n\,\mathbb{E}[p_i - p_i^2] + n^2\,\mathrm{Var}(p_i).
$$
Using Dirichlet moments
$\mathbb{E}[p_i]=\alpha_i/\alpha_0$ and $\mathrm{Var}(p_i)=\alpha_i(\alpha_0-\alpha_i)/(\alpha_0^2(\alpha_0+1))$
yields the variance formula in Section 4.
### Covariance
Similarly, for $i\ne j$:
$$
\mathrm{Cov}(X_i, X_j)
= \mathbb{E}[\mathrm{Cov}(X_i,X_j\mid p)] + \mathrm{Cov}(\mathbb{E}[X_i\mid p],\mathbb{E}[X_j\mid p]).
$$
For a multinomial, $\mathrm{Cov}(X_i,X_j\mid p) = -n p_i p_j$ for $i\ne j$.
With Dirichlet moments for $\mathbb{E}[p_i p_j]$, you arrive at the negative covariance formula.
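Both decompositions are easy to verify numerically. A quick sketch (reusing `dm_cov` from Section 4 and the Dirichlet moments above):

# Numeric check of the law-of-total-(co)variance decompositions
alpha = np.array([1.5, 2.0, 4.5])
n = 20
alpha0 = alpha.sum()
i, j = 0, 1
Ep = alpha / alpha0                                                  # E[p]
Vp_i = alpha[i] * (alpha0 - alpha[i]) / (alpha0**2 * (alpha0 + 1))   # Var(p_i)
var_i = n * (Ep[i] - (Vp_i + Ep[i] ** 2)) + n**2 * Vp_i              # E[Var(X_i|p)] + Var(E[X_i|p])
Epipj = alpha[i] * alpha[j] / (alpha0 * (alpha0 + 1))                # E[p_i p_j]
cov_ij = -n * Epipj + n**2 * (Epipj - Ep[i] * Ep[j])                 # E[Cov|p] + Cov(E|p)
print(var_i, dm_cov(alpha, n=n)[i, i])
print(cov_ij, dm_cov(alpha, n=n)[i, j])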
### Likelihood (for fitting $\alpha$)
Given an observed count vector $x$ with total $n$, the likelihood as a function of $\alpha$ is:
$$
L(\alpha; x)
\propto
\frac{\Gamma(\alpha_0)}{\Gamma(\alpha_0+n)}
\prod_{i=1}^K \frac{\Gamma(\alpha_i + x_i)}{\Gamma(\alpha_i)}.
$$
Taking logs gives:
$$
\ell(\alpha; x)
= \log\Gamma(\alpha_0) - \log\Gamma(\alpha_0+n)
+ \sum_i \big(\log\Gamma(\alpha_i + x_i) - \log\Gamma(\alpha_i)\big)
+ \text{const}(x).
$$
There is no closed-form MLE for $\alpha$ in general; you typically optimize $\ell(\alpha)$ numerically.
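For the numerical optimization in Section 9 it also helps to have the gradient. With $\psi$ the digamma function, differentiating $\ell$ gives

$$
\frac{\partial \ell}{\partial \alpha_i}
= \psi(\alpha_0) - \psi(\alpha_0 + n)
+ \psi(\alpha_i + x_i) - \psi(\alpha_i),
$$

which (summed over observations) is exactly what `dm_loglik_grad` in Section 9 implements.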
## 7) Sampling & Simulation
A simple **NumPy-only** sampling algorithm follows directly from the hierarchical story:
1. Sample $p \sim \mathrm{Dirichlet}(\alpha)$.
A standard implementation uses Gamma variables: draw $g_i \sim \mathrm{Gamma}(\alpha_i, 1)$ and set $p_i = g_i / \sum_j g_j$.
2. Sample counts $X \mid p \sim \mathrm{Multinomial}(n, p)$.
This is exactly what `dirichlet_multinomial_rvs_numpy` implements.
Below we verify mean/covariance by Monte Carlo.
alpha = np.array([1.5, 2.0, 4.5])
n = 25
theory_mean = dm_mean(alpha, n=n)
theory_cov = dm_cov(alpha, n=n)
samples = dirichlet_multinomial_rvs_numpy(alpha=alpha, n=n, size=50_000, rng=rng)
sample_mean = samples.mean(axis=0)
sample_cov = np.cov(samples.T, ddof=0)
print('theory mean:', theory_mean)
print('sample mean:', sample_mean)
print('\nmax abs mean error:', np.max(np.abs(sample_mean - theory_mean)))
print('\nmax abs cov error:', np.max(np.abs(sample_cov - theory_cov)))
## 8) Visualization
Because the Dirichlet–multinomial is multivariate, visuals depend on $K$:
- For $K=2$ it reduces to a **Beta–binomial** and you can plot a standard PMF/CDF over $\{0,1,\dots,n\}$.
- For $K=3$ you can plot probabilities/samples on a 2D simplex (triangle).
We do both.
# K=2: PMF and CDF (Beta–binomial view)
n = 30
alpha2 = np.array([2.0, 5.0])
xs = np.arange(n + 1)
pmf = np.array([dirichlet_multinomial_pmf([x, n - x], alpha=alpha2, n=n) for x in xs])
cdf = np.cumsum(pmf)
fig = go.Figure()
fig.add_trace(go.Bar(x=xs, y=pmf, name="PMF"))
fig.update_layout(
title="Dirichlet–multinomial with K=2 (PMF of X1)",
xaxis_title="x",
yaxis_title="P(X1=x)",
bargap=0.05,
width=850,
height=380,
)
fig.show()
fig = go.Figure()
fig.add_trace(go.Scatter(x=xs, y=cdf, mode="lines+markers", name="CDF"))
fig.update_layout(
title="Dirichlet–multinomial with K=2 (CDF of X1)",
xaxis_title="x",
yaxis_title="P(X1≤x)",
width=850,
height=380,
)
fig.show()
# Monte Carlo samples vs PMF (K=2)
size = 20_000
s = dirichlet_multinomial_rvs_numpy(alpha=alpha2, n=n, size=size, rng=rng)
x1 = s[:, 0]
fig = px.histogram(x1, nbins=n + 1, histnorm="probability", title="Monte Carlo histogram vs PMF")
fig.add_trace(go.Scatter(x=xs, y=pmf, mode="lines", name="PMF", line=dict(color="black")))
fig.update_layout(xaxis_title="x1", yaxis_title="probability", width=850, height=420)
fig.show()
# K=3: PMF on the simplex (small n, exact enumeration)
n = 20
alpha3 = np.array([1.2, 2.5, 4.0])
support = enumerate_support(n=n, k=3)
logp = dirichlet_multinomial_logpmf(support, alpha=alpha3, n=n)
p = np.exp(logp)
sx, sy = simplex_xy_3(support)
fig = go.Figure()
fig.add_trace(
go.Scatter(
x=sx,
y=sy,
mode="markers",
marker=dict(
size=10,
color=np.log10(p),
colorscale="Viridis",
colorbar=dict(title="log10 PMF"),
),
text=[str(tuple(row)) for row in support],
hovertemplate="x=%{text}<br>log10 p=%{marker.color:.3f}<extra></extra>",
name="support",
)
)
tri_x = [0.0, 1.0, 0.5, 0.0]
tri_y = [0.0, 0.0, np.sqrt(3) / 2.0, 0.0]
fig.add_trace(go.Scatter(x=tri_x, y=tri_y, mode="lines", line=dict(color="black"), showlegend=False))
fig.update_layout(
title="Dirichlet–multinomial PMF on the 3-simplex (exact, enumerated support)",
xaxis_title="barycentric x",
yaxis_title="barycentric y",
xaxis=dict(scaleanchor="y", scaleratio=1),
width=850,
height=520,
)
fig.show()
# K=3: Monte Carlo samples on the simplex
n = 20
alpha3 = np.array([1.2, 2.5, 4.0])
samples = dirichlet_multinomial_rvs_numpy(alpha=alpha3, n=n, size=6000, rng=rng)
x, y = simplex_xy_3(samples)
fig = go.Figure()
fig.add_trace(
go.Scattergl(
x=x,
y=y,
mode="markers",
marker=dict(size=4, opacity=0.25),
text=[str(tuple(row)) for row in samples],
hovertemplate="x=%{text}<extra></extra>",
name="samples",
)
)
tri_x = [0.0, 1.0, 0.5, 0.0]
tri_y = [0.0, 0.0, np.sqrt(3) / 2.0, 0.0]
fig.add_trace(go.Scatter(x=tri_x, y=tri_y, mode="lines", line=dict(color="black"), showlegend=False))
fig.update_layout(
title="Dirichlet–multinomial Monte Carlo samples on the 3-simplex",
xaxis_title="barycentric x",
yaxis_title="barycentric y",
xaxis=dict(scaleanchor="y", scaleratio=1),
width=850,
height=520,
)
fig.show()
## 9) SciPy Integration
SciPy exposes the Dirichlet–multinomial as `scipy.stats.dirichlet_multinomial`.
In this environment (SciPy version may vary), the object provides:
- `pmf` / `logpmf`
- moment methods like `mean`, `var`, and `cov`
But it may **not** provide `cdf`, `rvs`, or `fit` for this multivariate distribution.
We’ll show:
- how to use SciPy where available
- how to implement missing pieces (CDF by summation for small $n$, sampling via the hierarchical model, and MLE via `scipy.optimize`)
import scipy
print('SciPy version:', scipy.__version__)
n = 10
alpha = np.array([1.0, 2.0, 3.0])
x = np.array([2, 3, 5])
dm = stats.dirichlet_multinomial(n=n, alpha=alpha)
print('pmf:', dm.pmf(x))
print('logpmf:', dm.logpmf(x))
print('mean:', dm.mean())
print('cov:\n', dm.cov())
# Feature check
for name in ['cdf', 'rvs', 'fit']:
print(name, 'available?', hasattr(dm, name))
# CDF: SciPy doesn't implement a multivariate cdf here, but we can compute it by brute force for small n.
n = 12
alpha = np.array([1.0, 2.0, 3.0])
x = np.array([3, 4, 5])
dm_cdf_small_n(x, alpha=alpha, n=n)
0.016968325791855168
# For K=2, the CDF reduces to the usual univariate Beta–binomial CDF
n = 30
alpha2 = np.array([2.0, 5.0])
xs = np.arange(n + 1)
# X1 ~ BetaBinomial(n, a=alpha1, b=alpha2)
cdf_scipy = stats.betabinom.cdf(xs, n=n, a=alpha2[0], b=alpha2[1])
cdf_numpy = np.cumsum([dirichlet_multinomial_pmf([x, n - x], alpha=alpha2, n=n) for x in xs])
float(np.max(np.abs(cdf_scipy - cdf_numpy)))
2.55351295663786e-15
# Sampling: SciPy's dirichlet_multinomial may not expose rvs, but sampling is easy via the hierarchical model.
# Here is a SciPy-flavored sampler (Dirichlet from SciPy + Multinomial from NumPy):
def dirichlet_multinomial_rvs_scipy(alpha, n: int, size: int, rng: np.random.Generator) -> np.ndarray:
alpha = _validate_alpha(alpha)
ps = stats.dirichlet.rvs(alpha, size=size, random_state=rng)
out = np.empty((size, alpha.size), dtype=int)
for i, p in enumerate(ps):
out[i] = rng.multinomial(n, p)
return out
alpha = np.array([1.0, 2.0, 3.0])
n = 10
samples = dirichlet_multinomial_rvs_scipy(alpha=alpha, n=n, size=5, rng=rng)
samples
array([[1, 1, 8],
[4, 2, 4],
[0, 3, 7],
[0, 5, 5],
[2, 2, 6]])
# Fit (MLE): optimize the Dirichlet–multinomial log-likelihood for alpha
def dm_loglik(alpha, X: np.ndarray) -> float:
alpha = _validate_alpha(alpha)
X = _validate_counts(X, k=alpha.size)
n_vec = X.sum(axis=1)
alpha0 = alpha.sum()
# omit multinomial coefficient terms (constants wrt alpha)
ll = (
X.shape[0] * gammaln(alpha0)
- np.sum(gammaln(alpha0 + n_vec))
+ np.sum(gammaln(alpha + X) - gammaln(alpha), axis=1).sum()
)
return float(ll)
def dm_loglik_grad(alpha, X: np.ndarray) -> np.ndarray:
alpha = _validate_alpha(alpha)
X = _validate_counts(X, k=alpha.size)
n_vec = X.sum(axis=1)
alpha0 = alpha.sum()
m = X.shape[0]
common = m * digamma(alpha0) - np.sum(digamma(alpha0 + n_vec))
grad = common + np.sum(digamma(alpha + X), axis=0) - m * digamma(alpha)
return grad
def dm_mom_alpha_init(X: np.ndarray) -> np.ndarray:
'''Method-of-moments-ish initializer for alpha (works best when n is constant).'''
X = np.asarray(X, dtype=float)
X = np.atleast_2d(X)
n_vec = X.sum(axis=1)
if not np.allclose(n_vec, n_vec[0]):
# fall back: mean proportions with moderate concentration
p_hat = X.sum(axis=0) / X.sum()
return 20.0 * p_hat
n = float(n_vec[0])
p_hat = X.mean(axis=0) / n
s2 = X.var(axis=0, ddof=0)
# v_i ≈ Var / (n p(1-p)) = (n+alpha0)/(alpha0+1)
denom = n * p_hat * (1.0 - p_hat)
usable = denom > 1e-12
v = np.median((s2[usable] / denom[usable]).clip(min=1.0)) if np.any(usable) else 1.0
if v <= 1.0 + 1e-8:
alpha0 = 1e3
else:
alpha0 = (n - v) / (v - 1.0)
alpha0 = float(np.clip(alpha0, 1e-3, 1e4))
return alpha0 * p_hat
def fit_dirichlet_multinomial_mle(X: np.ndarray, alpha_init: np.ndarray | None = None) -> np.ndarray:
X = _validate_counts(X, k=np.asarray(X).shape[-1])
k = X.shape[1]
if alpha_init is None:
alpha_init = dm_mom_alpha_init(X)
alpha_init = np.asarray(alpha_init, dtype=float)
if alpha_init.shape != (k,):
raise ValueError(f"alpha_init must have shape ({k},)")
# optimize over log(alpha) to enforce positivity
x0 = np.log(alpha_init)
bounds = [(-10.0, 10.0)] * k # keeps alpha in a safe numeric range
def obj(log_alpha):
a = np.exp(log_alpha)
return -dm_loglik(a, X)
def grad(log_alpha):
a = np.exp(log_alpha)
return -(dm_loglik_grad(a, X) * a) # chain rule
res = minimize(obj, x0=x0, jac=grad, method='L-BFGS-B', bounds=bounds)
if not res.success:
raise RuntimeError(f"MLE optimization failed: {res.message}")
return np.exp(res.x)
# Demo: simulate + fit
rng_fit = np.random.default_rng(0)
alpha_true = np.array([1.2, 2.5, 4.0])
n = 20
m = 300
X = dirichlet_multinomial_rvs_numpy(alpha=alpha_true, n=n, size=m, rng=rng_fit)
alpha_hat = fit_dirichlet_multinomial_mle(X)
alpha_true, alpha_hat, alpha_hat / alpha_hat.sum()
(array([1.2, 2.5, 4. ]),
array([1.0448, 2.4962, 3.8207]),
array([0.1419, 0.3391, 0.519 ]))
## 10) Statistical Use Cases
### 1) Hypothesis testing: multinomial vs overdispersed counts
A common question is whether a plain multinomial is *too restrictive*.
You can compare:
- **$H_0$**: $X \sim \mathrm{Multinomial}(n, p)$ (fixed $p$)
- **$H_1$**: $X \sim \mathrm{DirichletMultinomial}(n, \alpha)$ (random $p$)
A likelihood ratio statistic can be used, but the usual $\chi^2$ reference is unreliable because the multinomial is a boundary case (roughly $\alpha_0\to\infty$).
A practical approach is a **parametric bootstrap** under $H_0$.
### 2) Bayesian modeling: posterior predictive
If you place a Dirichlet prior on multinomial probabilities, the posterior is Dirichlet and the **posterior predictive** for new counts is Dirichlet–multinomial.
### 3) Generative modeling
Dirichlet–multinomial is a natural “bag-of-words” generator: it samples a document-level word distribution and then generates word counts.
# Bayesian modeling: Dirichlet posterior + Dirichlet–multinomial posterior predictive
alpha_prior = np.array([1.0, 1.0, 1.0])
x_obs = np.array([4, 1, 5])
alpha_post = alpha_prior + x_obs
print('prior mean p:', alpha_prior / alpha_prior.sum())
print('posterior mean p:', alpha_post / alpha_post.sum())
# Posterior predictive for future n_new counts
n_new = 12
x_future = np.array([3, 5, 4])
p_pred = dirichlet_multinomial_pmf(x_future, alpha=alpha_post, n=n_new)
p_pred
prior mean p: [0.3333 0.3333 0.3333]
posterior mean p: [0.3846 0.1538 0.4615]
0.009784938442900492
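As a cross-check, the exact predictive above should match a Monte Carlo average of multinomial PMFs over posterior draws of $p$. A sketch using `dirichlet_rvs_numpy` from earlier:

# Monte Carlo cross-check: E_{p ~ Dirichlet(alpha_post)}[ Multinomial pmf(x_future | n_new, p) ]
ps = dirichlet_rvs_numpy(alpha_post, size=200_000, rng=rng)
log_coeff = gammaln(n_new + 1) - gammaln(x_future + 1).sum()
mc = np.mean(np.exp(log_coeff + (x_future * np.log(ps)).sum(axis=1)))
print('posterior predictive (exact):', float(p_pred))
print('posterior predictive (MC):   ', float(mc))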
# Hypothesis testing demo: parametric bootstrap LRT (small example)
def multinomial_loglik(X: np.ndarray, p: np.ndarray) -> float:
X = _validate_counts(X, k=p.size)
p = np.asarray(p, dtype=float)
if p.ndim != 1 or p.size != X.shape[1]:
raise ValueError('p must be shape (k,)')
if np.any(p <= 0):
raise ValueError('p must be strictly positive (use smoothing if needed)')
p = p / p.sum()
n_vec = X.sum(axis=1)
ll = (
gammaln(n_vec + 1)
- np.sum(gammaln(X + 1), axis=1)
+ (X * np.log(p)).sum(axis=1)
).sum()
return float(ll)
def lrt_statistic(X: np.ndarray) -> tuple[float, np.ndarray, np.ndarray]:
    X = _validate_counts(X, k=np.asarray(X).shape[-1])
    n_vec = X.sum(axis=1)
    if not np.all(n_vec == n_vec[0]):
        raise ValueError('This demo assumes constant n across rows')
    # H0: multinomial MLE for p
    p_hat = X.sum(axis=0) / X.sum()
    p_hat = (p_hat + 1e-12) / (p_hat.sum() + 1e-12 * p_hat.size)
    ll0 = multinomial_loglik(X, p_hat)
    # H1: Dirichlet–multinomial MLE for alpha.
    # dm_loglik drops the multinomial coefficients (constant in alpha), so add
    # them back here to make ll1 comparable with ll0; otherwise the statistic
    # is shifted by a large negative constant.
    alpha_hat = fit_dirichlet_multinomial_mle(X)
    log_coeff = float(np.sum(gammaln(n_vec + 1) - np.sum(gammaln(X + 1), axis=1)))
    ll1 = dm_loglik(alpha_hat, X) + log_coeff
    return 2.0 * (ll1 - ll0), p_hat, alpha_hat
rng_test = np.random.default_rng(123)
# Simulate an overdispersed dataset under H1
alpha_true = np.array([1.2, 2.5, 4.0])
n = 20
m = 80
X = dirichlet_multinomial_rvs_numpy(alpha=alpha_true, n=n, size=m, rng=rng_test)
lrt_obs, p_hat_obs, alpha_hat_obs = lrt_statistic(X)
print('Observed LRT:', lrt_obs)
print('alpha_hat:', alpha_hat_obs)
# Bootstrap under H0 (multinomial)
B = 30
lrt_boot = []
for _ in range(B):
Xb = rng_test.multinomial(n, p_hat_obs, size=m)
stat, _, _ = lrt_statistic(Xb)
lrt_boot.append(stat)
lrt_boot = np.array(lrt_boot)
p_value = float(np.mean(lrt_boot >= lrt_obs))
print('bootstrap LRT mean:', lrt_boot.mean())
print('bootstrap p-value (rough, small B):', p_value)
# Generative modeling example: "documents" as category-count vectors
alpha_topic = np.array([0.4, 0.4, 0.4]) # sparse-ish p for each document
n_words = 60
n_docs = 200
docs = dirichlet_multinomial_rvs_numpy(alpha=alpha_topic, n=n_words, size=n_docs, rng=rng)
# Visualize document-level proportions
props = docs / docs.sum(axis=1, keepdims=True)
fig = px.scatter_3d(
x=props[:, 0], y=props[:, 1], z=props[:, 2],
title="Document-level proportions (Dirichlet–multinomial generator)",
labels={'x': 'p1', 'y': 'p2', 'z': 'p3'}
)
fig.update_traces(marker=dict(size=3, opacity=0.6))
fig.show()
## 11) Pitfalls
- **Invalid parameters**:
  - $\alpha_i$ must be strictly positive.
  - $x_i$ must be nonnegative integers and must satisfy $\sum_i x_i = n$.
- **Numerical issues**:
  - PMFs can underflow quickly when $n$ is large. Prefer `logpmf` and compute in log-space.
  - Use `gammaln` / `digamma` rather than `gamma` / factorials.
- **Combinatorial explosion**:
  - The support size is $\binom{n+K-1}{K-1}$.
  - Exact enumeration (for entropy, CDF, full PMF plots) is only feasible for small $n$ and moderate $K$ (see the quick computation below).
- **Fitting**:
  - The multinomial is a limiting case ($\alpha_0\to\infty$). For near-multinomial data, the MLE may push $\alpha$ very large.
  - Use good initialization and consider bounds / regularization if optimization is unstable.
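To make the combinatorial point concrete, here is the quick computation referenced above (using Python's `math.comb`):

from math import comb

# Support size binom(n+K-1, K-1): fine for small n and K, explodes as K grows
for n, k in [(10, 3), (50, 5), (200, 10)]:
    print(f"n={n:4d}, K={k:2d} -> support size = {comb(n + k - 1, k - 1):,}")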
## 12) Summary
- The Dirichlet–multinomial is the **posterior predictive** distribution for multinomial counts with a Dirichlet prior.
- It models **overdispersed** multinomial counts by letting the category probabilities vary across replicates.
- Mean proportions are $\alpha/\alpha_0$; total concentration $\alpha_0$ controls dispersion.
- PMF evaluation is stable in log-space via Gamma functions.
- Exact CDF/entropy require summation over a combinatorial support; for larger problems use Monte Carlo or approximations.